# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import List, Optional, Tuple, Union

import mujoco
import numpy as np
import scipy
import termcolor
import tqdm
from PIL import Image as im
from PIL import ImageOps

# TODO: b/288149332 - Remove once USD Python Binding works well with pytype.
# pytype: disable=module-attr
from pxr import Sdf, Usd, UsdGeom

import robocasa.utils.usd.component as component_module
import robocasa.utils.usd.shapes as shapes_module


class USDExporter:
    """Mirrors a MuJoCo visual scene into an in-memory USD stage over time.

    Geoms, lights and cameras from the renderer's ``MjvScene`` are created as
    USD prims (via ``robocasa.utils.usd.component``) and updated with one time
    sample per ``update_scene`` call. Model textures are written as PNG files
    into ``<output_directory_root>/<output_directory_name>/assets`` and
    ``save_scene`` exports the stage into the sibling ``frames`` directory.
    """

    def __init__(
        self,
        model: mujoco.MjModel,
        height: int = 480,
        width: int = 480,
        max_geom: int = 10000,
        output_directory_name: str = "mujoco_usdpkg",
        output_directory_root: str = "./",
        light_intensity: int = 10000,
        camera_names: Optional[List[str]] = None,
        specialized_materials_file: Optional[str] = None,
        verbose: bool = True,
    ):
        """Initializes a new USD Exporter.

        Creates the output directories and writes all model textures to disk
        as a side effect.

        Args:
            model: an MjModel instance.
            height: image height in pixels; must not exceed the model's
              offscreen framebuffer height (``<global offheight>``).
            width: image width in pixels; must not exceed the model's
              offscreen framebuffer width (``<global offwidth>``).
            max_geom: maximum number of geoms that can be rendered in the same
              scene; passed through to ``mujoco.Renderer``.
            output_directory_name: name of root directory to store outputted
              frames and assets generated by the USD renderer.
            output_directory_root: path to root directory storing generated
              frames and assets by the USD renderer.
            light_intensity: intensity assigned to every exported scene light
              on each update.
            camera_names: names of MuJoCo cameras to export as USD cameras, or
              None to export no cameras.
            specialized_materials_file: stored on the instance; not read by
              any code in this class.
            verbose: decides whether to print updates (including the texture
              export progress bar).

        Raises:
            ValueError: if ``width`` or ``height`` exceeds the model's
              offscreen framebuffer size.
        """

        buffer_width = model.vis.global_.offwidth
        buffer_height = model.vis.global_.offheight

        if width > buffer_width:
            raise ValueError(
                f"""
                Image width {width} > framebuffer width {buffer_width}. Either reduce the image
                width or specify a larger offscreen framebuffer in the model XML using the
                clause:
                <visual>
                <global offwidth="my_width"/>
                </visual>""".lstrip()
            )

        if height > buffer_height:
            raise ValueError(
                f"""
                Image height {height} > framebuffer height {buffer_height}. Either reduce the
                image height or specify a larger offscreen framebuffer in the model XML using
                the clause:
                <visual>
                <global offheight="my_height"/>
                </visual>""".lstrip()
            )

        self.model = model
        self.height = height
        self.width = width
        self.max_geom = max_geom
        self.output_directory_name = output_directory_name
        self.output_directory_root = output_directory_root
        self.light_intensity = light_intensity
        self.camera_names = camera_names
        self.specialized_materials_file = specialized_materials_file
        self.verbose = verbose

        self.frame_count = 0  # maintains how many times we have saved the scene
        self.updates = 0  # number of update_scene calls; used as the USD time code

        # Maps "<name>_<objid>" (see _get_geom_name) to its USD component.
        self.geom_name2usd = {}

        # initializing rendering requirements
        self.renderer = mujoco.Renderer(model, height, width, max_geom)
        self._initialize_usd_stage()
        self._scene_option = mujoco.MjvOption()  # using default scene option

        # initializing output_directories
        self._initialize_output_directories()

        # loading required textures for the scene
        self._load_textures()

    @property
    def usd(self):
        """The stage's root layer serialized as a string."""
        return self.stage.GetRootLayer().ExportToString()

    @property
    def scene(self):
        """The renderer's current ``MjvScene``."""
        return self.renderer.scene

    def _initialize_usd_stage(self):
        """Creates a fresh in-memory, Z-up stage with a ``/World`` default prim."""
        self.stage = Usd.Stage.CreateInMemory()
        UsdGeom.SetStageUpAxis(self.stage, UsdGeom.Tokens.z)
        self.stage.SetStartTimeCode(0)
        # TODO: expose the time codes per second as a user input.
        self.stage.SetTimeCodesPerSecond(60.0)

        default_prim = UsdGeom.Xform.Define(self.stage, Sdf.Path("/World")).GetPrim()
        self.stage.SetDefaultPrim(default_prim)

    def _initialize_output_directories(self):
        """Creates the output, ``frames`` and ``assets`` directories if missing."""
        self.output_directory_path = os.path.join(
            self.output_directory_root, self.output_directory_name
        )
        self.frames_directory = os.path.join(self.output_directory_path, "frames")
        self.assets_directory = os.path.join(self.output_directory_path, "assets")

        # exist_ok avoids the check-then-create race of os.path.exists().
        for directory in (
            self.output_directory_path,
            self.frames_directory,
            self.assets_directory,
        ):
            os.makedirs(directory, exist_ok=True)

        if self.verbose:
            print(
                termcolor.colored(
                    "Writing output frames and assets to"
                    f" {self.output_directory_path}",
                    "green",
                )
            )

    def update_scene(
        self,
        data: mujoco.MjData,
        scene_option: Optional[mujoco.MjvOption] = None,
    ):
        """Updates the scene with latest sim data and records one USD frame.

        On the first call the USD stage is recreated and the scene's lights
        and cameras are loaded. Every call then writes geom, light and camera
        state at time code ``self.updates`` (incremented afterwards).

        Args:
            data: structure storing current simulation state
            scene_option: we use this to determine which geom groups to
              activate; falls back to the exporter's default ``MjvOption``.
        """

        self.frame_count += 1

        scene_option = scene_option or self._scene_option

        # update the mujoco renderer
        self.renderer.update_scene(data, scene_option=scene_option)

        # TODO: update scene options
        if self.updates == 0:
            # NOTE(review): this replaces the stage created in __init__, so any
            # prims added via add_light/add_camera before the first update are
            # discarded. Confirm this is intended.
            self._initialize_usd_stage()

            self._load_lights()
            self._load_cameras()

        self._update_geoms()
        self._update_lights()
        self._update_cameras(data, scene_option=scene_option)

        self.updates += 1

    def _load_textures(self):
        """Writes every model texture to ``assets/`` as a vertically flipped PNG.

        Populates ``self.texture_files`` (indexed by texture id) with paths
        relative to the frames directory, which is where ``save_scene``
        exports the stage.
        """
        # TODO: remove code once added internally to mujoco
        self.texture_files = []
        # Loop-invariant: every texture lives in the same assets directory.
        relative_path = os.path.relpath(self.assets_directory, self.frames_directory)

        data_adr = 0
        for texture_id in tqdm.tqdm(range(self.model.ntex), disable=not self.verbose):
            # int() keeps the size arithmetic in Python ints rather than the
            # model's fixed-width integer type.
            texture_height = int(self.model.tex_height[texture_id])
            texture_width = int(self.model.tex_width[texture_id])
            # Textures are stored back to back in tex_rgb, 3 values per pixel.
            pixels = 3 * texture_height * texture_width
            img = im.fromarray(
                self.model.tex_rgb[data_adr : data_adr + pixels].reshape(
                    texture_height, texture_width, 3
                )
            )
            img = ImageOps.flip(img)

            texture_file_name = f"texture_{texture_id}.png"
            img.save(os.path.join(self.assets_directory, texture_file_name))
            self.texture_files.append(os.path.join(relative_path, texture_file_name))

            data_adr += pixels

        if self.verbose:
            print(
                termcolor.colored(
                    f"Completed writing {self.model.ntex} textures to"
                    f" {self.assets_directory}",
                    "green",
                )
            )

    def _load_geom(self, geom: mujoco.MjvGeom):
        """Creates the USD component for ``geom`` and registers it by name.

        Mesh geoms become ``USDMesh``; every other geom type is tessellated
        via ``shapes_module.mesh_config_generator`` into a ``USDPrimitiveMesh``.
        """

        geom_name = self._get_geom_name(geom.objtype, geom.objid)

        # Internal invariant: callers only load geoms not yet registered.
        assert geom_name not in self.geom_name2usd

        texture_file = self.texture_files[geom.texid] if geom.texid != -1 else None

        if geom.type == mujoco.mjtGeom.mjGEOM_MESH:
            usd_geom = component_module.USDMesh(
                stage=self.stage,
                model=self.model,
                geom=geom,
                obj_name=geom_name,
                dataid=self.model.geom_dataid[geom.objid],
                rgba=geom.rgba,
                texture_file=texture_file,
            )
        else:
            mesh_config = shapes_module.mesh_config_generator(
                name=geom_name, geom_type=geom.type, size=geom.size
            )
            usd_geom = component_module.USDPrimitiveMesh(
                mesh_config=mesh_config,
                stage=self.stage,
                geom=geom,
                obj_name=geom_name,
                rgba=geom.rgba,
                texture_file=texture_file,
            )

        self.geom_name2usd[geom_name] = usd_geom

    def _update_geoms(self):
        """Writes pose/visibility of every scene geom at frame ``self.updates``.

        Geoms seen for the first time are created (and marked invisible at
        frame 0). Previously created geoms that are absent from the current
        scene are marked invisible at this frame.
        """
        # Registered geoms not (yet) seen in the current scene.
        unseen_names = set(self.geom_name2usd)

        # iterate through all geoms in the scene and make updates
        for i in range(self.scene.ngeom):
            geom = self.scene.geoms[i]
            geom_name = self._get_geom_name(geom.objtype, geom.objid)

            if geom_name not in self.geom_name2usd:
                self._load_geom(geom)
                # Presumably keeps late-appearing geoms hidden for the frames
                # before they first showed up.
                if self.geom_name2usd[geom_name]:
                    self.geom_name2usd[geom_name].update_visibility(False, 0)

            usd_geom = self.geom_name2usd[geom_name]
            if usd_geom:
                usd_geom.update(
                    pos=geom.pos,
                    mat=geom.mat,
                    visible=geom.rgba[3] > 0,
                    frame=self.updates,
                )
            unseen_names.discard(geom_name)

        for geom_name in unseen_names:
            usd_geom = self.geom_name2usd[geom_name]
            if usd_geom:
                usd_geom.update_visibility(False, self.updates)

    def _load_lights(self):
        """Creates one USD sphere light per scene light not at the origin.

        ``self.usd_lights`` is index-aligned with ``self.scene.lights``;
        lights positioned at the origin get a ``None`` placeholder.
        """
        self.usd_lights = []
        for i in range(self.scene.nlight):
            light = self.scene.lights[i]
            if np.allclose(light.pos, [0, 0, 0]):
                self.usd_lights.append(None)
            else:
                # Previously the call parentheses were on their own line, so
                # the light was built but never appended.
                self.usd_lights.append(
                    component_module.USDSphereLight(stage=self.stage, obj_name=str(i))
                )

    def _update_lights(self):
        """Writes position, intensity and diffuse color of exported lights."""
        for i in range(self.scene.nlight):
            light = self.scene.lights[i]

            if np.allclose(light.pos, [0, 0, 0]):
                continue

            # Lights that appeared after _load_lights have no USD prim.
            if i >= len(self.usd_lights) or self.usd_lights[i] is None:
                continue

            self.usd_lights[i].update(
                pos=light.pos,
                intensity=self.light_intensity,
                color=light.diffuse,
                frame=self.updates,
            )

    def _load_cameras(self):
        """Creates one ``USDCamera`` per entry of ``self.camera_names``."""
        self.usd_cameras = [
            component_module.USDCamera(stage=self.stage, obj_name=name)
            for name in (self.camera_names or [])
        ]

    def _update_cameras(
        self,
        data: mujoco.MjData,
        scene_option: Optional[mujoco.MjvOption] = None,
    ):
        """Writes the pose of each exported camera at frame ``self.updates``.

        Side effect: the renderer's scene is re-updated from each named
        camera in turn, so afterwards it reflects the last camera.
        """
        for camera, camera_name in zip(self.usd_cameras, self.camera_names or []):
            self.renderer.update_scene(
                data, scene_option=scene_option, camera=camera_name
            )

            # Average the scene's two GL cameras into a single pose.
            avg_camera = mujoco.mjv_averageCamera(
                self.scene.camera[0], self.scene.camera[1]
            )

            forward = avg_camera.forward
            up = avg_camera.up
            right = np.cross(forward, up)

            # Columns (right, up, -forward): presumably matching USD's
            # convention of cameras looking down their local -Z axis.
            rotation = np.eye(3)
            rotation[:, 0] = right
            rotation[:, 1] = up
            rotation[:, 2] = -forward

            camera.update(cam_pos=avg_camera.pos, cam_mat=rotation, frame=self.updates)

    def add_light(
        self,
        pos: List[float],
        intensity: int,
        objid: int,
        radius: Optional[float] = 1.0,
        color: Optional[np.ndarray] = None,
        obj_name: Optional[str] = "light_1",
        light_type: Optional[str] = "sphere",
    ):
        """Adds a static light to the stage at frame 0.

        Args:
            pos: light position (ignored for dome lights).
            intensity: light intensity.
            objid: used (as a string) for the USD object name.
            radius: sphere light radius (ignored for dome lights).
            color: light color; defaults to ``[0.3, 0.3, 0.3]``.
            obj_name: currently unused; kept for interface compatibility.
            light_type: ``"sphere"`` or ``"dome"``.

        Raises:
            ValueError: if ``light_type`` is not supported.
        """
        # Fresh array per call instead of a shared mutable default.
        if color is None:
            color = np.array([0.3, 0.3, 0.3])

        if light_type == "sphere":
            new_light = component_module.USDSphereLight(
                stage=self.stage, obj_name=str(objid), radius=radius
            )
            new_light.update(
                pos=np.array(pos), intensity=intensity, color=color, frame=0
            )
        elif light_type == "dome":
            new_light = component_module.USDDomeLight(
                stage=self.stage, obj_name=str(objid)
            )
            new_light.update(intensity=intensity, color=color, frame=0)
        else:
            raise ValueError(
                f"Unsupported light_type {light_type!r}; expected 'sphere' or 'dome'."
            )

    def add_camera(
        self,
        pos: List[float],
        rotation_xyz: List[float],
        objid: int,
        obj_name: Optional[str] = "camera_1",
    ):
        """Adds a static camera to the stage at frame 0.

        Args:
            pos: camera position.
            rotation_xyz: extrinsic "xyz" Euler angles in degrees.
            objid: used (as a string) for the USD object name.
            obj_name: currently unused; kept for interface compatibility.
        """
        # Explicit submodule import: a bare `import scipy` does not expose
        # scipy.spatial on SciPy versions without lazy submodule loading.
        from scipy.spatial.transform import Rotation

        new_camera = component_module.USDCamera(stage=self.stage, obj_name=str(objid))

        r = Rotation.from_euler("xyz", rotation_xyz, degrees=True)
        new_camera.update(cam_pos=np.array(pos), cam_mat=r.as_matrix(), frame=0)

    def save_scene(self, filetype: str = "usd"):
        """Exports the stage, with all frames recorded so far, to ``frames/``.

        Args:
            filetype: one of ``"usd"``, ``"usda"`` or ``"usdc"``.

        Raises:
            ValueError: if ``filetype`` is not supported.
        """
        if filetype not in ("usd", "usda", "usdc"):
            raise ValueError(
                f"Unsupported filetype {filetype!r}; expected 'usd', 'usda' or 'usdc'."
            )
        self.stage.SetEndTimeCode(self.frame_count)
        file_name = f"frame_{self.frame_count}_.{filetype}"
        self.stage.Export(os.path.join(self.frames_directory, file_name))
        if self.verbose:
            # Report the name of the file actually written.
            print(termcolor.colored(f"Completed writing {file_name}", "green"))

    def _get_geom_name(self, objtype, objid):
        """Returns ``"<name>_<objid>"``, using ``"None"`` for unnamed objects."""
        geom_name = mujoco.mj_id2name(self.model, objtype, objid)
        if not geom_name:
            geom_name = "None"
        geom_name += f"_{objid}"
        return geom_name
